import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.utils
import math
import numpy as np
import av
from io import BytesIO
from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io

class EmptyLTXVLatentVideo(io.ComfyNode):
    """Node that emits an all-zero LTXV video latent sized from pixel-space inputs."""

    @classmethod
    def define_schema(cls):
        """Declare the node's id, category, integer size inputs and latent output."""
        size_inputs = [
            io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
            io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
            io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
            io.Int.Input("batch_size", default=1, min=1, max=4096),
        ]
        return io.Schema(
            node_id="EmptyLTXVLatentVideo",
            category="latent/video/ltxv",
            inputs=size_inputs,
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
        """Build a zero-filled latent for the requested video size.

        The tensor has 128 channels, ``(length - 1) // 8 + 1`` frames and
        spatial dimensions ``height // 32`` by ``width // 32``; it is allocated
        on the intermediate device reported by ``comfy.model_management``.
        """
        latent_frames = ((length - 1) // 8) + 1
        latent_shape = [batch_size, 128, latent_frames, height // 32, width // 32]
        samples = torch.zeros(latent_shape, device=comfy.model_management.intermediate_device())
        return io.NodeOutput({"samples": samples})

    generate = execute  # TODO: remove

class LTXVImgToVideo(io.ComfyNode):
    """Node that seeds the start of an LTXV video latent with a VAE-encoded image."""

    @classmethod
    def define_schema(cls):
        """Declare conditioning/VAE/image inputs, size parameters and the three outputs."""
        node_inputs = [
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Vae.Input("vae"),
            io.Image.Input("image"),
            io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
            io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
            io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
            io.Int.Input("batch_size", default=1, min=1, max=4096),
            io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
        ]
        node_outputs = [
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
            io.Latent.Output(display_name="latent"),
        ]
        return io.Schema(
            node_id="LTXVImgToVideo",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput:
        """Encode ``image`` and place it at the beginning of a fresh zero latent.

        The image is resized to ``width`` x ``height`` (bilinear, center crop),
        reduced to its first three channels and encoded with ``vae``. The
        returned noise mask is 1.0 everywhere except over the encoded frames,
        where it is ``1.0 - strength``. The conditionings pass through untouched.
        """
        # common_upscale is fed channels-first data, so move channels in and back out.
        resized = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center")
        rgb_pixels = resized.movedim(1, -1)[:, :, :, :3]
        encoded = vae.encode(rgb_pixels)
        encoded_frames = encoded.shape[2]

        latent_shape = [batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32]
        samples = torch.zeros(latent_shape, device=comfy.model_management.intermediate_device())
        samples[:, :, :encoded_frames] = encoded

        frame_mask = torch.ones(
            (batch_size, 1, samples.shape[2], 1, 1),
            dtype=torch.float32,
            device=samples.device,
        )
        frame_mask[:, :, :encoded_frames] = 1.0 - strength

        return io.NodeOutput(positive, negative, {"samples": samples, "noise_mask": frame_mask})

    generate = execute  # TODO: remove


def conditioning_get_any_value(conditioning, key, default=None):
    """Return the value for ``key`` from the first conditioning entry that has it.

    Each entry is indexed as ``entry[1]`` to reach its options mapping. Entries
    are scanned in order; ``default`` is returned when none contains ``key``.
    """
    matches = (entry[1][key] for entry in conditioning if key in entry[1])
    return next(matches, default)


def get_noise_mask(latent):
    """Return a noise mask for ``latent`` that the caller is free to modify.

    If the latent dict already carries a ``"noise_mask"``, a clone of it is
    returned. Otherwise an all-ones float32 mask of shape
    ``(batch, 1, frames, 1, 1)`` is created on the samples' device, with
    ``batch`` and ``frames`` read from the 5-D ``"samples"`` tensor.
    """
    existing_mask = latent.get("noise_mask", None)
    samples = latent["samples"]
    if existing_mask is not None:
        return existing_mask.clone()

    batch, _, frames, _, _ = samples.shape
    return torch.ones((batch, 1, frames, 1, 1), dtype=torch.float32, device=samples.device)

def get_keyframe_idxs(cond):
    """Look up ``"keyframe_idxs"`` in ``cond`` and count its keyframes.

    Returns ``(keyframe_idxs, num_keyframes)``, where ``num_keyframes`` is the
    number of distinct values found in ``keyframe_idxs[:, 0]``. When the
    conditioning has no such entry the result is ``(None, 0)``.
    """
    idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
    if idxs is not None:
        distinct = torch.unique(idxs[:, 0])
        return idxs, distinct.shape[0]
    return None, 0

class LTXVAddGuide(io.ComfyNode):
    """Node that conditions an LTXV latent video on a guide image or video.

    The guide is VAE-encoded; its first ``NUM_PREFIX_FRAMES`` latent frames (or
    fewer if the guide is shorter) are appended to the end of the latent as
    keyframes and recorded in the conditioning under ``"keyframe_idxs"``, while
    any remaining guide latent frames overwrite frames of the latent in place.
    """

    # Upper bound on how many leading guide latent frames execute() appends as keyframes.
    NUM_PREFIX_FRAMES = 2
    # Used in add_keyframe_index() only for the latent coordinates it returns.
    PATCHIFIER = SymmetricPatchifier(1)

    @classmethod
    def define_schema(cls):
        """Declare the node's id, category, inputs (with tooltips) and outputs."""
        return io.Schema(
            node_id="LTXVAddGuide",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Latent.Input("latent"),
                io.Image.Input(
                    "image",
                    tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
                            "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
                ),
                io.Int.Input(
                    "frame_idx",
                    default=0,
                    min=-9999,
                    max=9999,
                    tooltip="Frame index to start the conditioning at. "
                            "For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
                            "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
                            "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                ),
                io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def encode(cls, vae, latent_width, latent_height, images, scale_factors):
        """Trim, resize and VAE-encode the guide frames.

        ``scale_factors`` is unpacked here as (time, width, height). The frame
        count is cut down to the largest ``k * time_scale_factor + 1`` that
        fits, the frames are resized (bilinear, no crop) to the latent size
        multiplied by the spatial scale factors, and only the first three
        channels are encoded.

        Returns:
            ``(encode_pixels, t)``: the pixels that were encoded and the VAE output.
        """
        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
        # Keep k * time_scale_factor + 1 frames, dropping any surplus at the end.
        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        return encode_pixels, t

    @classmethod
    def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
        """Resolve ``frame_idx`` to a non-negative frame index and a latent index.

        Keyframes already recorded in ``cond`` are subtracted from
        ``latent_length`` before a negative ``frame_idx`` is counted back from
        the end (clamped at 0). For guides longer than one frame, a non-zero
        index is snapped down so that ``frame_idx - 1`` is a multiple of the
        time scale factor. The latent index is ``frame_idx`` divided by the
        time scale factor, rounded up.

        Returns:
            ``(frame_idx, latent_idx)``.
        """
        time_scale_factor, _, _ = scale_factors
        _, num_keyframes = get_keyframe_idxs(cond)
        # Exclude previously recorded keyframes from the usable latent length.
        latent_count = latent_length - num_keyframes
        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
        if guide_length > 1 and frame_idx != 0:
            frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1 # frame index - 1 must be divisible by 8 or frame_idx == 0

        # Ceiling division of frame_idx by time_scale_factor.
        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor

        return frame_idx, latent_idx

    @classmethod
    def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
        """Record the pixel coordinates of ``guiding_latent`` in ``cond``.

        Latent coordinates from the patchifier are converted to pixel
        coordinates, offset by ``frame_idx`` along coordinate axis 0, and
        concatenated (dim 2) onto any existing ``"keyframe_idxs"`` value.

        Returns:
            The conditioning produced by ``node_helpers.conditioning_set_values``.
        """
        keyframe_idxs, _ = get_keyframe_idxs(cond)
        _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0)  # we need the causal fix only if we're placing the new latents at index 0
        # NOTE(review): coordinate axis 0 is presumably the temporal axis — confirm against latent_to_pixel_coords.
        pixel_coords[:, 0] += frame_idx
        if keyframe_idxs is None:
            keyframe_idxs = pixel_coords
        else:
            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

    @classmethod
    def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
        """Append ``guiding_latent`` to the latent as keyframes.

        ``noise_mask`` is written to in place: the frames starting at the
        resolved latent index are set to 1.0. Both conditionings get the
        keyframe coordinates added, and the guide plus a mask filled with
        ``1.0 - strength`` are concatenated onto the latent and noise mask
        along dim 2.

        Returns:
            ``(positive, negative, latent_image, noise_mask)`` with the appended frames.
        """
        _, latent_idx = cls.get_latent_index(
            cond=positive,
            latent_length=latent_image.shape[2],
            guide_length=guiding_latent.shape[2],
            frame_idx=frame_idx,
            scale_factors=scale_factors,
        )
        # In-place write on the caller's tensor, done before the concatenation below.
        noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0

        positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
        negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)

        # Mask for the appended frames; takes its batch and trailing dims from noise_mask.
        mask = torch.full(
            (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
        noise_mask = torch.cat([noise_mask, mask], dim=2)
        return positive, negative, latent_image, noise_mask

    @classmethod
    def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
        """Overwrite latent frames with ``guiding_latent`` starting at ``latent_idx``.

        Works on clones, so the inputs are left unmodified. The matching noise
        mask frames are set to ``1.0 - strength``.

        Raises:
            AssertionError: if the guide would run past the end of the latent.

        Returns:
            ``(latent_image, noise_mask)`` clones with the frames replaced.
        """
        cond_length = guiding_latent.shape[2]
        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."

        mask = torch.full(
            (noise_mask.shape[0], 1, cond_length, 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = latent_image.clone()
        noise_mask = noise_mask.clone()

        latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
        noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask

        return latent_image, noise_mask

    @classmethod
    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
        """Encode the guide and merge it into the latent and conditionings.

        The first ``min(NUM_PREFIX_FRAMES, t.shape[2])`` encoded frames are
        appended as keyframes; any remaining encoded frames replace latent
        frames starting right after that prefix.

        Raises:
            AssertionError: if the guide does not fit in the latent at the resolved index.
        """
        scale_factors = vae.downscale_index_formula
        latent_image = latent["samples"]
        # A clone or a fresh tensor, so append_keyframe's in-place write does not touch the input latent's mask.
        noise_mask = get_noise_mask(latent)

        _, _, latent_length, latent_height, latent_width = latent_image.shape
        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)

        # len(image) is the guide's frame count after trimming in encode().
        frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."

        num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])

        positive, negative, latent_image, noise_mask = cls.append_keyframe(
            positive,
            negative,
            frame_idx,
            latent_image,
            noise_mask,
            t[:, :, :num_prefix_frames],
            strength,
            scale_factors,
        )

        # The rest of the guide is written after the prefix frames.
        latent_idx += num_prefix_frames

        t = t[:, :, num_prefix_frames:]
        if t.shape[2] == 0:
            # The whole guide fit into the keyframe prefix; nothing left to replace.
            return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

        latent_image, noise_mask = cls.replace_latent_frames(
            latent_image,
            noise_mask,
            t,
            latent_idx,
            strength,
        )

        return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

    generate = execute  # TODO: remove


class LTXVCropGuides(io.ComfyNode):
    """Node that strips trailing keyframe frames from a latent and its conditioning."""

    @classmethod
    def define_schema(cls):
        """Declare conditioning and latent inputs together with the mirrored outputs."""
        node_inputs = [
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Latent.Input("latent"),
        ]
        node_outputs = [
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
            io.Latent.Output(display_name="latent"),
        ]
        return io.Schema(
            node_id="LTXVCropGuides",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, positive, negative, latent) -> io.NodeOutput:
        """Remove the last ``num_keyframes`` frames and clear ``"keyframe_idxs"``.

        The keyframe count is read from the positive conditioning. When it is
        zero the cloned samples and noise mask are returned unchanged, as are
        both conditionings.
        """
        samples = latent["samples"].clone()
        mask = get_noise_mask(latent)

        num_keyframes = get_keyframe_idxs(positive)[1]
        if num_keyframes != 0:
            # Drop the keyframe frames from the end of the frame axis.
            samples = samples[:, :, :-num_keyframes]
            mask = mask[:, :, :-num_keyframes]

            positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
            negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})

        return io.NodeOutput(positive, negative, {"samples": samples, "noise_mask": mask})

    crop = execute  # TODO: remove


class LTXVConditioning(io.ComfyNode):
    """Node that stamps a frame rate onto the positive and negative conditioning."""

    @classmethod
    def define_schema(cls):
        """Declare the two conditioning inputs, the frame-rate input and the outputs."""
        node_inputs = [
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
        ]
        node_outputs = [
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
        ]
        return io.Schema(
            node_id="LTXVConditioning",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, positive, negative, frame_rate) -> io.NodeOutput:
        """Set ``"frame_rate"`` on both conditionings and return them."""
        # The generator is consumed in order: positive first, then negative.
        positive, negative = (
            node_helpers.conditioning_set_values(cond, {"frame_rate": frame_rate})
            for cond in (positive, negative)
        )
        return io.NodeOutput(positive, negative)


class ModelSamplingLTXV(io.ComfyNode):
    """Node that patches a model's sampling object with a token-dependent shift."""

    @classmethod
    def define_schema(cls):
        """Declare the model input, the two shift bounds, the optional latent and the output."""
        node_inputs = [
            io.Model.Input("model"),
            io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
            io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
            io.Latent.Input("latent", optional=True),
        ]
        return io.Schema(
            node_id="ModelSamplingLTXV",
            category="advanced/model",
            inputs=node_inputs,
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput:
        """Clone ``model`` and install a shifted ``model_sampling`` object patch.

        The shift is a linear function of the token count that equals
        ``base_shift`` at 1024 tokens and ``max_shift`` at 4096 tokens. The
        token count is the product of the latent's dimensions from index 2
        onward, or 4096 when no latent is given.
        """
        patched = model.clone()

        tokens = 4096 if latent is None else math.prod(latent["samples"].shape[2:])

        # Line through (low_tokens, base_shift) and (high_tokens, max_shift).
        low_tokens, high_tokens = 1024, 4096
        slope = (max_shift - base_shift) / (high_tokens - low_tokens)
        intercept = base_shift - slope * low_tokens
        shift = (tokens) * slope + intercept

        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingFlux, comfy.model_sampling.CONST):
            pass

        sampling = ModelSamplingAdvanced(model.model.model_config)
        sampling.set_parameters(shift=shift)
        patched.add_object_patch("model_sampling", sampling)

        return io.NodeOutput(patched)


class LTXVScheduler(io.ComfyNode):
    """Node that builds a shifted, optionally stretched, sigma schedule for LTXV."""

    @classmethod
    def define_schema(cls):
        """Declare step count, shift bounds, stretch options, optional latent and the output."""
        return io.Schema(
            node_id="LTXVScheduler",
            category="sampling/custom_sampling/schedulers",
            inputs=[
                io.Int.Input("steps", default=20, min=1, max=10000),
                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
                io.Boolean.Input(
                    id="stretch",
                    default=True,
                    tooltip="Stretch the sigmas to be in the range [terminal, 1].",
                ),
                io.Float.Input(
                    id="terminal",
                    default=0.1,
                    min=0.0,
                    max=0.99,
                    step=0.01,
                    tooltip="The terminal value of the sigmas after stretching.",
                ),
                io.Latent.Input("latent", optional=True),
            ],
            outputs=[
                io.Sigmas.Output(),
            ],
        )

    @classmethod
    def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput:
        """Compute ``steps + 1`` sigmas running from 1.0 down to 0.0.

        A linear ramp is warped by a shift that depends on the token count
        (product of the latent's dimensions from index 2 onward, or 4096 when
        no latent is given); the shift equals ``base_shift`` at 1024 tokens
        and ``max_shift`` at 4096 tokens. With ``stretch`` enabled, the
        non-zero sigmas are rescaled so the last of them equals ``terminal``.

        When the last non-zero sigma is already exactly 1.0 (e.g. ``steps=1``)
        there is no range to stretch, so the sigmas are left as they are
        instead of being turned into NaN by a 0/0 division.
        """
        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Line through (1024, base_shift) and (4096, max_shift), evaluated at `tokens`.
        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = (tokens) * mm + b

        power = 1
        # 1 / sigmas is inf at sigma == 0; torch.where replaces that entry with 0.
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that its final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            # A zero gap would make scale_factor 0 and the division below 0/0 (NaN).
            if one_minus_z.numel() > 0 and one_minus_z[-1] > 0:
                scale_factor = one_minus_z[-1] / (1.0 - terminal)
                stretched = 1.0 - (one_minus_z / scale_factor)
                sigmas[non_zero_mask] = stretched

        return io.NodeOutput(sigmas)

def encode_single_frame(output_file, image_array: np.ndarray, crf):
    """Write ``image_array`` to ``output_file`` as a one-frame libx264 MP4.

    ``image_array`` is read as an rgb24 image whose first two dimensions are
    height and width; it is converted to yuv420p before encoding. ``crf`` is
    passed to the encoder as a string option alongside the "veryfast" preset.
    The container is closed even if encoding fails.
    """
    container = av.open(output_file, "w", format="mp4")
    try:
        encoder_options = {"crf": str(crf), "preset": "veryfast"}
        stream = container.add_stream("libx264", rate=1, options=encoder_options)
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]

        rgb_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24")
        yuv_frame = rgb_frame.reformat(format="yuv420p")

        container.mux(stream.encode(yuv_frame))
        # A second encode() call without a frame drains what the encoder still holds.
        container.mux(stream.encode())
    finally:
        container.close()


def decode_single_frame(video_file):
    """Decode the first frame of the first video stream in ``video_file``.

    Returns the frame as an rgb24 ndarray. The container is closed before the
    frame is converted. ``StopIteration`` propagates if there is no video
    stream or no decodable frame.
    """
    container = av.open(video_file)
    try:
        video_streams = (s for s in container.streams if s.type == "video")
        first_stream = next(video_streams)
        first_frame = next(container.decode(first_stream))
    finally:
        container.close()
    return first_frame.to_ndarray(format="rgb24")


def preprocess(image: torch.Tensor, crf=29):
    """Round-trip one image through single-frame H.264 encoding.

    With ``crf == 0`` the input tensor is returned as is. Otherwise the image
    is cropped to even height and width, scaled by 255 and cast to bytes,
    encoded and decoded in memory, and returned divided by 255 with the
    input's dtype and device.
    """
    if crf == 0:
        return image

    # NOTE(review): the even-size crop is presumably required by yuv420p encoding — confirm.
    even_height = (image.shape[0] // 2) * 2
    even_width = (image.shape[1] // 2) * 2
    frame_bytes = (image[:even_height, :even_width] * 255.0).byte().cpu().numpy()

    with BytesIO() as encoded_buffer:
        encode_single_frame(encoded_buffer, frame_bytes, crf)
        encoded_video = encoded_buffer.getvalue()

    with BytesIO(encoded_video) as decode_buffer:
        decoded = decode_single_frame(decode_buffer)

    return torch.tensor(decoded, dtype=image.dtype, device=image.device) / 255.0


class LTXVPreprocess(io.ComfyNode):
    """Node that applies the module-level ``preprocess`` to each image in a batch."""

    @classmethod
    def define_schema(cls):
        """Declare the image input, the compression amount and the image output."""
        node_inputs = [
            io.Image.Input("image"),
            io.Int.Input(
                id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
            ),
        ]
        return io.Schema(
            node_id="LTXVPreprocess",
            category="image",
            inputs=node_inputs,
            outputs=[io.Image.Output(display_name="output_image")],
        )

    @classmethod
    def execute(cls, image, img_compression) -> io.NodeOutput:
        """Process every image along dim 0 with ``img_compression`` as crf and restack."""
        # `preprocess` here is the module-level function, not the class alias below.
        processed = [preprocess(image[index], img_compression) for index in range(image.shape[0])]
        return io.NodeOutput(torch.stack(processed))

    preprocess = execute  # TODO: remove

class LtxvExtension(ComfyExtension):
    """Extension that exposes the LTXV node classes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension, in a fixed order."""
        node_classes: list[type[io.ComfyNode]] = [
            EmptyLTXVLatentVideo,
            LTXVImgToVideo,
            ModelSamplingLTXV,
            LTXVConditioning,
            LTXVScheduler,
            LTXVAddGuide,
            LTXVPreprocess,
            LTXVCropGuides,
        ]
        return node_classes


async def comfy_entrypoint() -> LtxvExtension:
    """Create and return a new ``LtxvExtension`` instance."""
    extension = LtxvExtension()
    return extension
